# -*- coding: utf-8 -*-
"""
Created on Tue Sep 20 13:42:34 2022

@author: obiri
"""

import mpctools as mpc
import casadi
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
from scipy import linalg
from numpy import random

from MHE_model_normalised import reactor_tank
from MHE_model_normalised import measurement_xy_model
# HACK: time.clock was removed in Python 3.8. It is aliased to time.time here,
# presumably because mpctools (or one of its dependencies) still calls
# time.clock -- confirm before removing this line.
time.clock = time.time



random.seed(927) # Seed numpy's random number generator so the noise draws below are reproducible.

# Flag intended to switch the actual-vs-estimate figures at the end of the
# script on or off.
# NOTE(review): confirm the plotting section actually reads this flag. The
# plotting code does not write the figures to graphics files.
doPlots = True

# True: full information estimation -- the estimation window grows to include
# every measurement since time 0.
# False: moving horizon estimation (MHE) -- the window holds only the most
# recent Nt sample intervals.
fullInformation = False


Nt = 50 # MHE window length, in sample intervals
# Delta = 60 # Time step
Delta = 60 # Time step between samples (the plots label time in minutes)
Nsim = 50 # Number of simulated sample intervals
tplot = np.arange(Nsim+1)*Delta # Time of each sample point: 0, Delta, ..., Nsim*Delta


# Nx = 19   # Number of system states
Nx = 16   # Number of system states
Nu = 8   # Number of system inputs

# Number of measured outputs. measurement_xy_model maps the Nx = 16 states to
# Ny = 7 measurements, so only part of the state is measured directly and the
# rest must be reconstructed by the estimator.
Ny = 7
Nw = Nx  # Number of process disturbances (one per state)
Nv = Ny  # Number of measurement noise terms (one per output)


## Make covariance matrices.
# sigma_v = 0*0.0035 # 0.001 Standard deviation of the measurements
# sigma_w = 0.003 #0.001 Standard deviation for the process noise
# sigma_p = 50 #0.05 # Standard deviation for prior
# Standard deviation of the prior on the initial state. P = diag(sigma_p**2)
# below is the prior covariance. lxfunc uses inv(P) as the weight of the
# prior (arrival) cost, so a large sigma_p means a weak prior.
sigma_p = 50

## Per-channel noise standard deviations assumed by the estimator:
## sigma_v has one entry per measured output (-> R), and sigma_w has one entry
## per state / process-noise channel (-> Q).

# sigma_v = 0.0095*np.array([0.05,  0.05,  0.05, 0.009, 0.009, 0.0095, 0.009, 0.0095, 0.009])
sigma_v = 0.0095*np.array([0.05,  0.05,  0.05, 0.009, 0.009, 0.0095, 0.009])
# sigma_w = 0.5*np.array([0.05, 1e7, 0.08, 0.0005, 0.005, 0.0008, 0.003, 0.005, 1e6, 0.015, 0.001, 0.003, 0.00095, 0.005, 0.003, 0.0001])
sigma_w = 0.5*np.array([5*1e6, 0.5*1e5, 0.9, 0.0005, 0.005, 0.0008, 0.003, 0.005, 1e6, 0.015, 0.0004, 0.003, 0.00075, 0.005, 0.003, 0.0001])

# # Make covariance matrices.
# R = np.diag((sigma_v*np.ones((Nv,)))**2) + 500*np.eye((Nv))
# Q = np.diag((sigma_w*np.ones((Nw,)))**2) +500*np.eye((Nw))

# # Make covariance matrices.
# Diagonal covariances: variance = sigma**2 per channel. sigma_v and sigma_w
# are already vectors, so the ones() factor only matters if a scalar sigma is
# used (as in the commented-out lines above).
R = np.diag((sigma_v*np.ones((Nv,)))**2) # Measurement noise covariance (Ny x Ny)
Q = np.diag((sigma_w*np.ones((Nw,)))**2) # Process noise covariance (Nx x Nx)
P = np.diag((sigma_p*np.ones((Nx,)))**2) # Covariance for prior.
# P = np.eye((Nx))
# P = np.eye((Nx))

#FOR 19 STATES
# x0 = np.array([3.22456499e+10, 8.74344999e+10, 2.21584610e+03, 3.56206935e+00,

#         6.72674220e+01, 5.03509229e+00, 6.80219613e+01, 2.99999103e+03,
#         2.58024685e+09, 6.99637296e+09, 1.77308565e+03, 2.85031259e+00,
#         5.38263469e+01, 4.02900269e+00, 5.44301175e+01, 2.97694159e+03,
#         3.69613398e+01, 1.50000000e+01, 5.44301175e+01])

#FOR 16 STATES
# Initial state in actual (un-normalised) units for the 16-state model.
# Kept for reference only: the plant simulation below starts from x0_norm.
x0 = np.array([3.22456499e+10, 8.74344999e+10, 2.21584610e+03, 3.56206935e+00,
        6.72674220e+01, 5.03509229e+00, 6.80219613e+01,
        2.58024685e+09, 6.99637296e+09, 1.77308565e+03, 2.85031259e+00,
        5.38263469e+01, 4.02900269e+00, 5.44301175e+01,
        3.69613398e+01, 5.44301175e+01])
#ORIGINAL X0 VALUES USED-0.5 WAS THE NORMALISED STEADY-STATE FOR ALL THE STATES
# (With the scaling x = x_norm*(ub-lb)*xs + lb*xs, lb = 0.5 and ub = 1.5, that
# is applied at the end of the script, x_norm = 0.5 maps back to xs.)
# Initial TRUE state of the simulated plant, in normalised units. It is the
# per-state base values scaled by 1.15.
x0_norm =1.15 *np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.3, 0.25, 0.5]) #X0 FOR ACTUAL, NORMALISED VALUES
# x0_norm =1.15 *np.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
x_0 = 1.2*x0_norm # Initial guess for MHE: 20% above the true initial state, to test convergence from a wrong guess
#CONVERTING THE X0 ABOVE TO NORMALISED VALUES. NB:THE X0 ABOVE(ACTUAL) REPRESENTS A DIFFERENT STEADY-STATE .
# x0_norm = np.array([0.484999, 0.554913, 0.315424, 0.115631, 0.554916, 0.542165, 0.590329, 0.484975, 0.554888, 0.315405, 0.115617, 0.554891, 0.54214, 0.590303, 0.498012, 0.590303])
# x_0 = 1.1*x0_norm



#P = np.outer((x0_norm - x_0), (x0_norm - x_0)) 


# Input sequence applied to the simulated plant and passed to the estimator:
# u[k, :] is the input over sample interval k.
u = np.zeros((Nsim,Nu))
# Nominal input vector (one entry per input).
us=np.array([2.16100550e+01, 2.16155565e+01, 5.50143837e-03, 2.16100550e+01,
        2.19999850e+03, 1.30642193e+01, 4.06999944e+01, 2.16100550e+01])

## To apply the nominal input at every step instead, use:  u[:, :] = us


#VARYING THE FLOWRATE TO INCREASE THE OSCILLATIONS IN THE PLOTS
# Each us_j is the whole nominal input vector scaled by a factor. The horizon
# is split into consecutive segments, and segment j applies us_o[j].
us0=0.98*us
us1=0.99*us
us2=1.01*us
us3=1*us
us4=1*us
us_o = np.array([us0, us1, us2, us3, us4])
# np.array_split splits the Nsim rows into len(us_o) consecutive,
# near-equal segments for any Nsim. For Nsim = 50 this gives exactly rows
# 0:10, 10:20, ..., 40:50. Hard-coded slices would silently leave rows
# beyond 50 at zero input if Nsim were increased.
for seg, rows in enumerate(np.array_split(np.arange(Nsim), len(us_o))):
    u[rows, :] = us_o[seg]



#testing in the console
# import numpy as np
# Nsim =50
# Nu=8
# u = np.zeros((Nsim,Nu))
# us=np.array([2.16100550e+01, 2.16155565e+01, 5.50143837e-03, 2.16100550e+01,
#         2.19999850e+03, 1.30642193e+01, 4.06999944e+01, 2.16100550e+01])
# us0=0.95*us
# us1=0.98*us
# us2=1*us
# us3=1.1*us
# us4=1.12*us
# us5=1.12*us
# us6=1.12*us
# us7=1.05*us
# us8=1.0*us
# us9=1.0*us

# us_o = np.array([us0, us1, us2, us3, us4, us5, us6, us7, us8, us9])
# #nb:last index is left out
# for i in range(us_o.shape[1]):
#     u[0:5, i]  = us_o[0,i]
#     u[5:10, i] = us_o[1,i]
#     u[10:15,i] = us_o[2,i]
#     u[15:20,i] = us_o[3,i]
#     u[20:25,i] = us_o[4,i]
#     u[25:30,i] = us_o[5,i]
#     u[30:35,i] = us_o[6,i]
#     u[35:40,i] = us_o[7,i]
#     u[40:45,i] = us_o[8,i]
#     u[45:50,i] = us_o[9,i]
  
    



# Argument layout shared by the plant model reactor_tank:
# state x (Nx), input u (Nu), process noise w (Nw).
_xuw_sizes = [Nx, Nu, Nw]
_xuw_names = ["x", "u", "w"]

# Plant simulator. Its .sim(x, u, w) call, used in the simulation loop,
# returns the state one sample period Delta later.
model_reactor_simulator = mpc.DiscreteSimulator(
    reactor_tank, Delta, list(_xuw_sizes), list(_xuw_names))

# Estimator model: explicit discrete-time map F(x, u, w) built from the
# continuous-time reactor_tank with a single (M=1) RK4 step of length Delta.
# NOTE(review): if the dynamics are fast relative to Delta, a larger M may be
# needed for F to track the simulator -- confirm.
F = mpc.getCasadiFunc(reactor_tank, list(_xuw_sizes), list(_xuw_names), "F",
                      rk4=True, Delta=Delta, M=1)
#F = mpc.getCasadiFunc(reactor_tank,   [Nx, Nu], ['x','u'], funcname="odef")

# Measurement map H(x) -> y used by the estimator.
H = mpc.getCasadiFunc(measurement_xy_model, [Nx], ["x"], "H")


# Define stage costs.
def lfunc(w,v):
    """Return the MHE stage cost w' inv(Q) w + v' inv(R) v.

    w is the process-noise vector (length Nw) and v the measurement-noise
    vector (length Nv); Q and R are the module-level noise covariances.
    """
    Q_inv = linalg.inv(Q)
    R_inv = linalg.inv(R)
    process_term = mpc.mtimes(w.T, Q_inv, w)
    measurement_term = mpc.mtimes(v.T, R_inv, v)
    return process_term + measurement_term


l = mpc.getCasadiFunc(lfunc, [Nw, Nv], ["w", "v"], "l")


def lxfunc(x):
    """Return the arrival (prior) cost x' inv(P) x for a length-Nx vector x."""
    P_inv = linalg.inv(P)
    return mpc.mtimes(x.T, P_inv, x)


lx = mpc.getCasadiFunc(lxfunc, [Nx], ["x"], "lx")



# Noise levels actually injected into the simulated plant. They are kept
# separate from sigma_v / sigma_w (the values the estimator assumes through
# R and Q), presumably so plant and estimator noise can be tuned
# independently. They currently hold the same values.
sigma_v2 = 0.0095*np.array([0.05,  0.05,  0.05, 0.009, 0.009, 0.0095, 0.009])
# sigma_wa = 0.5*np.array([5*1e6, 0.5*1e5, 0.1, 0.0005, 0.005, 0.0008, 0.003, 0.005, 1e6, 0.015, 0.0004, 0.003, 0.00075, 0.005, 0.003, 0.0001])
sigma_wa = 0.5*np.array([5*1e6, 0.5*1e5, 0.9, 0.0005, 0.005, 0.0008, 0.003, 0.005, 1e6, 0.015, 0.0004, 0.003, 0.00075, 0.005, 0.003, 0.0001])

# Draw the noise sequences. Process noise is drawn first, then measurement
# noise; keep this order so the seeded random stream is reproducible.
w = sigma_wa*random.randn(Nsim,Nw) # process noise w[k]: enters the state update
v = sigma_v2*random.randn(Nsim,Nv) # measurement noise v[k]: added to y[k]


# w = sigma_w*random.randn(Nsim,Nw) # process noise drawn with the estimator's sigma_w
# v = sigma_v*random.randn(Nsim,Nv) # measurement noise drawn with the estimator's sigma_v


# The plant is driven by the designed input sequence u (same array object).
usim = u
# True normalised states: Nsim+1 rows, including the initial state.
xsim = np.zeros((Nsim+1,Nx))
# xsim[0,:] = x0
xsim[0] = x0_norm
yclean = np.zeros((Nsim, Ny))   # noise-free measurements h(x[k])
ysim = np.zeros((Nsim, Ny))     # noisy measurements handed to the estimator


# Simulate the process dynamics: record y[k] = h(x[k]) + v[k], then step the
# plant to x[k+1] with input u[k] and process noise w[k].
for k in range(Nsim):
    x_k = xsim[k]
    yclean[k] = measurement_xy_model(x_k)
    ysim[k] = yclean[k] + v[k]
    xsim[k + 1] = model_reactor_simulator.sim(x_k, usim[k], w[k])


# Now do estimation.
# xhat[t]  : MHE estimate xhat(t | t), the last state of the window ending at t.
# xhat_[t] : one-step model prediction xhat(t | t-1) = F(xhat[t-1], u[t-1], 0).
xhat_ = np.zeros((Nsim+1,Nx))
xhat = np.zeros((Nsim,Nx))
yhat = np.zeros((Nsim,Ny))
# The MHE problem also estimates the noise sequences. The newest values
# returned at each step are stored here.
vhat = np.zeros((Nsim,Nv))
what = np.zeros((Nsim,Nw))


x0bar = x_0
xhat[0,:] = x0bar
guess = {}


# NOTE(review): solveroptions is never passed to mpc.nmhe below, so it has no
# effect as written.
solveroptions = {
            'linear_solver':'mumps'
            }

# Box bounds on the normalised states (each state in [0, 1]). The arrays are
# loop-invariant, so they are built once. Fresh dicts are still created for
# every solver call.
ub_val = 1*np.ones(Nx)
# ub_val[7] =0.34
# ub_val[17] =0.57
lb_val = 0*np.ones(Nx)
# lb_val[7] =0.3

# Timers start at -t0 and later add t1, which leaves the elapsed time t1 - t0.
totaltime = -time.time()
for t in range(1, Nsim):
    # Window sizes. Full information estimation uses every measurement since
    # time 0. MHE uses only the most recent Nt intervals, [tmin, t].
    N = {"x":Nx, "y":Ny, "u":Nu}
    if fullInformation:
        N["t"] = t
        tmin = 0
    else:
        N["t"] = min(t,Nt)
        tmin = max(0,t - Nt)
    tmax = t+1

    lb = {'x': lb_val}
    ub = {'x': ub_val}

    # Build and solve the estimation problem with mpctools.
    # NOTE(review): lx and x0bar are not passed to nmhe, so the objective has
    # no arrival-cost (prior) term. As a result, the EKF update of P / lx
    # further down does not influence the estimates. Pass lx=lx, x0bar=x0bar
    # to enable the arrival cost.
    buildtime = -time.time()
    solver = mpc.nmhe(f=F, h=H, u=usim[tmin:tmax-1,:],
                      y=ysim[tmin:tmax,:], l=l, N=N,
                      verbosity=0,
                      lb=lb, ub=ub, guess=guess,Delta=Delta)
    buildtime += time.time()
    solvetime = -time.time()
    sol = mpc.callSolver(solver)
    solvetime += time.time()
    print ("%3d (%5.3g s build, %5.3g s solve): %s"
           % (t, buildtime, solvetime, sol["status"]))
    if sol["status"] != "Solve_Succeeded":
        # Rows t onward of xhat / yhat / vhat / what stay at zero.
        break
    xhat[t,:] = sol["x"][-1,...] # This is xhat( t  | t )
    yhat[t,:] = measurement_xy_model(xhat[t,:])
    vhat[t,:] = sol["v"][-1,...]
    # t starts at 1, so row t-1 always exists.
    what[t-1,:] = sol["w"][-1,...]

    # Apply model function to get xhat(t+1 | t )
    xhat_[t+1,:] = np.squeeze(F(xhat[t,:], usim[t,:], np.zeros((Nw,))))

    # Warm-start the next solve with this solution. Keys are cycled in a
    # fixed order.
    guess = {k: sol[k].copy() for k in ("x", "w", "v") if k in sol}

    # Once the MHE window is full, slide it forward by one point.
    if not fullInformation and t + 1 > Nt:
        for k in guess:
            guess[k] = guess[k][1:,...] # Get rid of oldest measurement.

        # Do EKF to update prior covariance, but don't take EKF state. Remove if arrival cost is not needed
        [P, x0bar, _, _] = mpc.ekf(F,H,x=sol["x"][0,...],
            u=usim[tmin,:],w=sol["w"][0,...],y=ysim[tmin,:],P=P,Q=Q,R=R)

        # Need to redefine arrival cost.
        def lxfunc(x):
            return mpc.mtimes(x.T,linalg.inv(P),x)
        lx = mpc.getCasadiFunc(lxfunc,[Nx],["x"],"lx")

    # Add final guess state for new time point (repeat the last entry).
    for k in guess:
        guess[k] = np.concatenate((guess[k],guess[k][-1:,...]))

totaltime += time.time()
print ( "Simulation took %.5g s." % totaltime )


# Steady state (actual units) about which the model is normalised.
# The commented-out array is the 19-state version.
# xs = np.array([3.27367434e+10, 8.28831113e+10, 2.71741538e+03, 5.78604334e+00,
#         6.37656737e+01, 4.83137928e+00, 6.23866477e+01, 3.00039503e+03,
#         2.61960613e+09, 6.63233551e+09, 2.17448574e+03, 4.63001301e+00,
#         5.10255242e+01, 3.86608729e+00, 4.99220220e+01, 2.97694060e+03,
#         3.70349577e+01, 1.50000000e+01, 4.99220218e+01])

xs = np.array([3.27367434e+10, 8.28831113e+10, 2.71741538e+03, 5.78604334e+00,
        6.37656737e+01, 4.83137928e+00, 6.23866477e+01,
        2.61960613e+09, 6.63233551e+09, 2.17448574e+03, 4.63001301e+00,
        5.10255242e+01, 3.86608729e+00, 4.99220220e+01,
        3.70349577e+01, 4.99220218e+01])

# Scaling parameters for the ODE:  x = x_norm * delta + minvalue, where
# delta = (ub - lb) * xs and minvalue = lb * xs. So x_norm = 0 maps to lb*xs,
# x_norm = 1 maps to ub*xs, and x_norm = 0.5 maps back to xs.
# NOTE: lb / ub here are scalars. They replace the bound dicts of the same
# names used in the estimation loop.
lb = 0.5 # lower bound of normalized model: lb*xs
ub = 1.5 # upper bound of normalized model: ub*xs
width = ub-lb
# Element-wise over all states, so these follow len(xs) automatically instead
# of hard-coding 16 entries.
delta = width*xs
minvalue = lb*xs

# Recover the actual states and estimates from the normalised ones.
# xsim has Nsim+1 rows; only the first Nsim are paired with the estimates.
x_actual_hat = xhat*delta + minvalue
x_actual = xsim[:Nsim]*delta + minvalue


np.savetxt("x_from_py.txt",x_actual)      # this saves the data of x
np.savetxt("estimate_from_mhe.txt",x_actual_hat)      # this saves the data of the MHE estimate

np.save("x_actual.npy",x_actual)      # this saves the data of x
np.save("x_hat.npy",x_actual_hat)     # this saves the data of the MHE estimate




##plots##

# Y-axis labels of the 16 states, in state order. They are drawn as two 4x2
# figures of 8 states each.
_state_ylabels = ['Xv_1 (cell/L)', 'Xt1 (cell/L)', 'GLC1 (mM)', 'GLN1 (mM)',
                  'LAC1 (mM)', 'AMM1 (mM)', 'mAb1 (mg/L)', 'Xv2 (cell/L)',
                  'Xt2 (cell/L)', 'GLC2 (mM)', 'GLN2 (mM)', 'LAC2 (mM)',
                  'AMM2 (mM)', 'mAb2 (mg/L)', 'T (d C)', 'c_buffer (mg/L)']


def _plot_state_grid(first_state):
    """Plot actual vs. estimated trajectories of 8 consecutive states.

    States first_state .. first_state+7 are drawn row-major on a 4x2 grid:
    actual values in solid blue, MHE estimates in dashed red, both against
    tplot[:-1] (one time point per stored sample).

    Returns the (figure, axes array) pair.
    """
    fig, axes = plt.subplots(4, 2)
    times = tplot[:-1]
    for offset, ax in enumerate(axes.flat):
        idx = first_state + offset
        ax.plot(times, x_actual[:, idx], 'tab:blue', linewidth=2)
        ax.plot(times, x_actual_hat[:, idx], '--r', linewidth=2)
        ax.set_ylabel(_state_ylabels[idx])
    # The labels attach to the first two lines drawn: actual, then estimate.
    fig.legend(['Actual', 'Estimate'])
    plt.setp(axes[-1, :], xlabel='Time(min)')
    return fig, axes


if doPlots:
    # Set the default figure size BEFORE creating the figures. Assigning it
    # after plt.subplots leaves that figure at the previous size.
    plt.rcParams["figure.figsize"] = [8.00, 7.00]
    fig1, axs = _plot_state_grid(0)
    fig2, axs = _plot_state_grid(8)










